from torch import nn


class FeedForward(nn.Module):
    """Feed-forward network: Linear -> GELU -> Dropout -> Linear.

    The first linear layer maps ``dim`` features to ``int(dim * mult)``
    hidden features; the last linear layer maps them back to ``dim``, so
    the output has the same feature width as the input.

    Args:
        dim: Number of input (and output) features.
        mult: Multiplier applied to ``dim`` to get the hidden width. The
            product is truncated with ``int``, so fractional values work.
        dropout: Probability passed to ``nn.Dropout``, which sits between
            the activation and the output projection.
    """

    def __init__(self, dim, mult=4, dropout=0.):
        super().__init__()
        hidden_width = int(dim * mult)

        # Build the layers in the order they are applied. Keeping this order
        # (and the ``net`` attribute name) keeps the parameter names stable:
        # ``net.0.*`` for the expansion, ``net.3.*`` for the projection.
        expand = nn.Linear(dim, hidden_width)
        activation = nn.GELU()
        drop = nn.Dropout(dropout)
        project = nn.Linear(hidden_width, dim)

        self.net = nn.Sequential(expand, activation, drop, project)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result.

        ``x`` is handed straight to ``nn.Linear``, so its last dimension
        must equal ``dim``.
        """
        out = self.net(x)
        return out